from __future__ import annotations

import torch

from typing import Tuple

@torch.no_grad()
def flatten_update(update: dict[str, torch.Tensor]) -> Tuple[torch.Tensor, list[Tuple[str, torch.Size]]]:
    """Concatenate every tensor of ``update`` into a single 1-D tensor.

    Tensors are flattened and concatenated in the iteration order of
    ``update``.

    Args:
        update: Mapping from parameter name to that parameter's tensor.

    Returns:
        A ``(flat_update, structure)`` pair. ``flat_update`` is the 1-D
        concatenation of all tensors; ``structure`` lists one
        ``(name, size)`` tuple per entry, in the same order, so the original
        shapes can be restored later. For an empty ``update`` the result is
        an empty 1-D tensor and an empty list.
    """
    structure = [(name, param_update.size()) for name, param_update in update.items()]
    if not structure:
        # torch.cat rejects an empty list of tensors, so handle "nothing to
        # flatten" explicitly instead of surfacing an opaque RuntimeError.
        return torch.empty(0), structure

    # reshape (not view) so non-contiguous tensors such as transposed weights
    # are accepted; torch.cat builds a new tensor regardless.
    flat_update = torch.cat([param_update.reshape(-1) for param_update in update.values()])

    return flat_update, structure

@torch.no_grad()
def flatten_updates(updates: dict[str, dict[str, torch.Tensor]]) -> Tuple[torch.Tensor, dict]:
    """Flatten each client's update and stack the results row by row.

    Args:
        updates: Mapping from client index to that client's update, itself a
            mapping from parameter name to tensor.

    Returns:
        A ``(flat_updates, structure)`` pair. ``flat_updates`` stacks one
        flattened update per client, in the iteration order of ``updates``;
        it is an empty tensor when ``updates`` is empty. ``structure`` is a
        dict with two keys: ``'client_idxs'`` (the client indices in row
        order) and ``'update_structure'`` (the ``(name, size)`` layout
        returned by ``flatten_update``).

    Note:
        Only the layout of the last client processed is kept in
        ``'update_structure'``; no check is made that the other clients
        share it.
    """
    client_idxs = []
    rows = []
    update_structure = []
    for client_idx, client_update in updates.items():
        client_idxs.append(client_idx)
        # Rebound on every pass, so the value that survives the loop is the
        # layout of the final client.
        row, update_structure = flatten_update(client_update)
        rows.append(row)

    # torch.stack cannot take an empty list; fall back to an empty tensor.
    stacked = torch.stack(rows) if rows else torch.Tensor()

    return stacked, {
        'client_idxs': client_idxs,
        'update_structure': update_structure,
    }

@torch.no_grad()
def unflatten_update(flat_update: torch.Tensor, structure: dict) -> dict[str, torch.Tensor]:
    """Split a flat tensor back into named tensors of their recorded shapes.

    Args:
        flat_update: Flat tensor holding the values of all parameters back
            to back.
        structure: Dict whose ``'update_structure'`` entry is a sequence of
            ``(name, size)`` tuples giving the order and shape of each
            parameter inside ``flat_update``.

    Returns:
        Mapping from parameter name to the matching slice of ``flat_update``
        viewed with its recorded size. Entries appear in the order given by
        ``'update_structure'``.
    """
    layout: list[Tuple[str, torch.Size]] = structure['update_structure']
    restored = {}
    start = 0
    for key, shape in layout:
        # Each parameter occupies the next shape.numel() elements.
        stop = start + shape.numel()
        restored[key] = flat_update[start:stop].view(shape)
        start = stop
    return restored